package area


import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}

/**
 * UDAF that aggregates city-name strings into a single comma-separated
 * list of DISTINCT cities, e.g. "Beijing,Shanghai,Shenzhen".
 *
 * Buffer format invariant: the single StringType buffer column always holds
 * zero or more non-empty city names joined by [[Sep]], with no leading or
 * trailing separator. Both `update` and `merge` preserve this invariant,
 * so `evaluate` can return the buffer as-is.
 */
class CityInfoSUDAF extends UserDefinedAggregateFunction {
  // Delimiter used to join distinct city names inside the buffer.
  private val Sep = ","

  // 1. inputSchema: the function's parameter list, expressed as a StructType.
  override def inputSchema: StructType = StructType(Array(StructField("cityInfo", StringType)))

  // 2. bufferSchema: the type of the intermediate (partial) result.
  override def bufferSchema: StructType = StructType(Array(StructField("cityInfo", StringType)))

  // 3. dataType: the type of the final result.
  override def dataType: DataType = StringType

  // 4. deterministic: the same input always produces the same output.
  override def deterministic: Boolean = true

  // 5. initialize: reset the buffer before aggregation starts (empty list).
  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = ""

  // 6. update: fold one input row into the per-partition buffer.
  // FIX: the original deduped with String.contains, which is substring
  // matching ("NewYork" was wrongly treated as present when the buffer held
  // "NewYorkCity"), and appended with no separator, so city names within a
  // partition ran together. We now compare whole tokens and join with Sep.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val city = input.getString(0)
    if (city != null && city.nonEmpty) {          // guard null/empty input rows
      val current = buffer.getString(0)
      if (!current.split(Sep).contains(city)) {
        buffer(0) = if (current.isEmpty) city else current + Sep + city
      }
    }
  }

  // 7. merge: combine the partial buffers of two partitions.
  // FIX: the original concatenated the buffers blindly (keeping duplicates
  // across partitions, defeating the dedup done in update) and appended a
  // stray "," after every partition. We dedup whole tokens here as well and
  // drop empty segments so the buffer invariant holds.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val merged = (buffer1.getString(0).split(Sep) ++ buffer2.getString(0).split(Sep))
      .filter(_.nonEmpty)
      .distinct
      .mkString(Sep)
    buffer1(0) = merged
  }

  // 8. evaluate: produce the final result from the fully-merged buffer.
  // FIX: the original stripped the last character unconditionally, which
  // throws StringIndexOutOfBoundsException when aggregating zero rows (empty
  // buffer) and chopped a real character whenever no trailing comma existed.
  // With the buffer invariant above, the buffer is already the answer.
  override def evaluate(buffer: Row): Any = buffer.getString(0)
}
